In [1]:
module_path = ['../../dataset_scripts', '../../../caffe/python']

import os
import sys
for path in module_path:
    p = os.path.abspath(os.path.join(path))
    if p not in sys.path:
        sys.path.append(p)
In [2]:
import cv2 as cv
import numpy as np
import scipy
import PIL.Image
import math
import caffe
import time
from config_reader import config_reader
import util
import copy
import matplotlib
%matplotlib inline
import pylab as plt
from scipy.spatial import distance
from scipy.ndimage.filters import gaussian_filter
from scipy.io import loadmat
from numpy import ma
import pickle
import cuhk_large
from sklearn.metrics import average_precision_score, precision_recall_curve
from skimage.feature import hog

Prepare dataset

In [3]:
dataset_root = '../../dataset/cuhk_large/dataset'
dataset = cuhk_large.CUHK_Large(dataset_root)
In [4]:
test_pool_size = dataset.get_test_pool_size()
print('Total test pool size: %d' % (test_pool_size))

query_size = dataset.get_test_query_size()
gallery_size = dataset.get_test_query_gallery_size()
print('Total query size: %d with gallery size: %d' % (query_size, gallery_size))
Total test pool size: 6978
Total query size: 2900 with gallery size: 100

Setup pose extraction model

id key point (heatmap) limb (paf)
0 nose neck -> r-hip (x)
1 neck r-hip -> neck (y)
2 r-shoulder r-hip -> r-knee (x)
3 r-elbow r-knee -> r-hip (y)
4 r-wrist r-knee -> r-ankle (x)
5 l-shoulder r-ankle -> r-knee (y)
6 l-elbow neck -> l-hip (x)
7 l-wrist l-hip -> neck (y)
8 r-hip l-hip -> l-knee (x)
9 r-knee l-knee -> l-hip (y)
10 r-ankle l-knee -> l-ankle (x)
11 l-hip l-ankle -> l-knee (y)
12 l-knee neck -> r-shoulder (x)
13 l-ankle r-shoulder -> neck (y)
14 r-eye r-shoulder -> r-elbow (x)
15 l-eye r-elbow -> r-shoulder (y)
16 r-ear r-elbow -> r-wrist (x)
17 l-ear r-wrist -> r-elbow (y)
18 r-shoulder -> r-ear (x)
19 r-ear -> r-shoulder (y)
20 neck -> l-shoulder (x)
21 l-shoulder -> neck (y)
22 l-shoulder -> l-elbow (x)
23 l-elbow -> l-shoulder (y)
24 l-elbow -> l-wrist (x)
25 l-wrist -> l-elbow (y)
26 l-shoulder -> l-ear (x)
27 l-ear -> l-shoulder (y)
28 neck -> nose (x)
29 nose -> neck (y)
30 nose -> r-eye (x)
31 r-eye -> nose (y)
32 nose -> l-eye (x)
33 l-eye -> nose (y)
34 r-eye -> r-ear (x)
35 r-ear -> r-eye (y)
36 l-eye -> l-ear (x)
37 l-ear -> l-eye (y)
In [5]:
param, model = config_reader()
print(param)
print(model)

net = caffe.Net(model['deployFile'], model['caffemodel'], caffe.TEST)
{'use_gpu': 1, 'GPUdeviceNumber': 0, 'modelID': '1', 'octave': 3, 'starting_range': 0.8, 'ending_range': 2.0, 'scale_search': [0.5, 1.0, 1.5, 2.0], 'thre1': 0.1, 'thre2': 0.05, 'thre3': 0.5, 'min_num': 4, 'mid_num': 10, 'crop_ratio': 2.5, 'bbox_ratio': 0.25}
{'caffemodel': './_trained_COCO/pose_iter_440000.caffemodel', 'deployFile': './_trained_COCO/pose_deploy.prototxt', 'description': 'COCO Pose56 Two-level Linevec', 'boxsize': 368, 'padValue': 128, 'np': '12', 'stride': 8, 'part_str': ['[nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear', 'pt19]']}
In [6]:
key_point_count = 19
limb_count = 38
mid_num = 10

# find connection in the specified sequence, center 29 is in the position 15
limbSeq = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10], \
           [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17], \
           [1,16], [16,18], [3,17], [6,18]]
# the middle joints heatmap correpondence
mapIdx = [[31,32], [39,40], [33,34], [35,36], [41,42], [43,44], [19,20], [21,22], \
          [23,24], [25,26], [27,28], [29,30], [47,48], [49,50], [53,54], [51,52], \
          [55,56], [37,38], [45,46]]
#         [12,13], [20,21], [14,15], [16,17], [22,23], [24,25], [0,1], [2,3],
#         [4,5], [6,7], [8,9], [10,11], [28,29], [30,31], [34,35], [32,33],
#         [36,37], [18,19], [26,27]
In [7]:
def find_heatmap_and_paf(oriImg):
    if param['use_gpu']: 
        caffe.set_mode_gpu()
        caffe.set_device(param['GPUdeviceNumber']) # set to your device!
    else:
        caffe.set_mode_cpu()
    
    multiplier = [x * model['boxsize'] / np.max(oriImg.shape) for x in param['scale_search']]
    
    heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], key_point_count))
    paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], limb_count))
    
    for m in range(len(multiplier)):
        scale = multiplier[m]
        imageToTest = cv.resize(oriImg, (0,0), fx=scale, fy=scale, interpolation=cv.INTER_CUBIC)
        imageToTest_padded, pad = util.padRightDownCorner(imageToTest, model['stride'], model['padValue'])
#         print(imageToTest_padded.shape)
        
        net.blobs['data'].reshape(*(1, 3, imageToTest_padded.shape[0], imageToTest_padded.shape[1]))
        #net.forward() # dry run
        net.blobs['data'].data[...] = np.transpose(np.float32(imageToTest_padded[:,:,:,np.newaxis]), (3,2,0,1))/256 - 0.5;
        start_time = time.time()
        output_blobs = net.forward()
#         print('At scale %d, The CNN took %.2f ms.' % (m, 1000 * (time.time() - start_time)))
        
        # extract outputs, resize, and remove padding
        heatmap = np.transpose(np.squeeze(net.blobs[list(output_blobs.keys())[1]].data), (1,2,0)) # output 1 is heatmaps
        heatmap = cv.resize(heatmap, (0,0), fx=model['stride'], fy=model['stride'], interpolation=cv.INTER_CUBIC)
        heatmap = heatmap[:imageToTest_padded.shape[0]-pad[2], :imageToTest_padded.shape[1]-pad[3], :]
        heatmap = cv.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)

        paf = np.transpose(np.squeeze(net.blobs[list(output_blobs.keys())[0]].data), (1,2,0)) # output 0 is PAFs
        paf = cv.resize(paf, (0,0), fx=model['stride'], fy=model['stride'], interpolation=cv.INTER_CUBIC)
        paf = paf[:imageToTest_padded.shape[0]-pad[2], :imageToTest_padded.shape[1]-pad[3], :]
        paf = cv.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv.INTER_CUBIC)
        
        heatmap_avg = heatmap_avg + heatmap / len(multiplier)
        paf_avg = paf_avg + paf / len(multiplier)
        
    return (heatmap_avg, paf_avg)

def find_all_peaks(heatmap_avg):
    all_peaks = []
    peak_counter = 0

    for part in range(key_point_count-1):
        x_list = []
        y_list = []
        map_ori = heatmap_avg[:,:,part]
        map = gaussian_filter(map_ori, sigma=3)

        map_left = np.zeros(map.shape)
        map_left[1:,:] = map[:-1,:]
        map_right = np.zeros(map.shape)
        map_right[:-1,:] = map[1:,:]
        map_up = np.zeros(map.shape)
        map_up[:,1:] = map[:,:-1]
        map_down = np.zeros(map.shape)
        map_down[:,:-1] = map[:,1:]

        peaks_binary = np.logical_and.reduce((map>=map_left, map>=map_right, map>=map_up, map>=map_down, map > param['thre1']))
        peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
        peaks_with_score = [x + (map_ori[x[1],x[0]],) for x in peaks]
        id = range(peak_counter, peak_counter + len(peaks))
        peaks_with_score_and_id = [peaks_with_score[i] + (id[i],) for i in range(len(id))]

        all_peaks.append(peaks_with_score_and_id)
        peak_counter += len(peaks)
    
    return (all_peaks, peak_counter)

def find_subset_and_candidate(oriImg, paf_avg, all_peaks):
    connection_all = []
    special_k = []
    
    for k in range(len(mapIdx)):
        score_mid = paf_avg[:,:,[x-key_point_count for x in mapIdx[k]]]
        candA = all_peaks[limbSeq[k][0]-1]
        candB = all_peaks[limbSeq[k][1]-1]
        nA = len(candA)
        nB = len(candB)
        indexA, indexB = limbSeq[k]
        if(nA != 0 and nB != 0):
            connection_candidate = []
            for i in range(nA):
                for j in range(nB):
                    vec = np.subtract(candB[j][:2], candA[i][:2])
                    norm = math.sqrt(vec[0]*vec[0] + vec[1]*vec[1])
                    
                    if norm <= 0:
                        continue
                    
                    vec = np.divide(vec, norm)

                    startend = zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
                                   np.linspace(candA[i][1], candB[j][1], num=mid_num))

                    vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
                                      for I in range(len(startend))])
                    vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
                                      for I in range(len(startend))])

                    score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
                    score_with_dist_prior = sum(score_midpts)/len(score_midpts) + min(0.5*oriImg.shape[0]/norm-1, 0)
                    criterion1 = len(np.nonzero(score_midpts > param['thre2'])[0]) > 0.8 * len(score_midpts)
                    criterion2 = score_with_dist_prior > 0
                    if criterion1 and criterion2:
                        connection_candidate.append([i, j, score_with_dist_prior, score_with_dist_prior+candA[i][2]+candB[j][2]])

            connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
            connection = np.zeros((0,5))
            for c in range(len(connection_candidate)):
                i,j,s = connection_candidate[c][0:3]
                if(i not in connection[:,3] and j not in connection[:,4]):
                    connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
                    if(len(connection) >= min(nA, nB)):
                        break

            connection_all.append(connection)
        else:
            special_k.append(k)
            connection_all.append([])
            
    # last number in each row is the total parts number of that person
    # the second last number in each row is the score of the overall configuration
    subset = -1 * np.ones((0, 20))
    candidate = np.array([item for sublist in all_peaks for item in sublist])
    
    for k in range(len(mapIdx)):
        if k not in special_k:
            partAs = connection_all[k][:,0]
            partBs = connection_all[k][:,1]
            indexA, indexB = np.array(limbSeq[k]) - 1

            for i in range(len(connection_all[k])): #= 1:size(temp,1)
                found = 0
                subset_idx = [-1, -1]
                for j in range(len(subset)): #1:size(subset,1):
                    if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
                        subset_idx[found] = j
                        found += 1

                if found == 1:
                    j = subset_idx[0]
                    if(subset[j][indexB] != partBs[i]):
                        subset[j][indexB] = partBs[i]
                        subset[j][-1] += 1
                        subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
                elif found == 2: # if found 2 and disjoint, merge them
                    j1, j2 = subset_idx
#                     print "found = 2"
                    membership = ((subset[j1]>=0).astype(int) + (subset[j2]>=0).astype(int))[:-2]
                    if len(np.nonzero(membership == 2)[0]) == 0: #merge
                        subset[j1][:-2] += (subset[j2][:-2] + 1)
                        subset[j1][-2:] += subset[j2][-2:]
                        subset[j1][-2] += connection_all[k][i][2]
                        subset = np.delete(subset, j2, 0)
                    else: # as like found == 1
                        subset[j1][indexB] = partBs[i]
                        subset[j1][-1] += 1
                        subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]

                # if find no partA in the subset, create a new subset
                elif not found and k < 17:
                    row = -1 * np.ones(20)
                    row[indexA] = partAs[i]
                    row[indexB] = partBs[i]
                    row[-1] = 2
                    row[-2] = sum(candidate[connection_all[k][i,:2].astype(int), 2]) + connection_all[k][i][2]
                    subset = np.vstack([subset, row])
                    
    # delete some rows of subset which has few parts occur
    deleteIdx = [];
    for k in range(len(subset)):
        if subset[k][-1] < 4 or subset[k][-2]/subset[k][-1] < 0.4:
            deleteIdx.append(k)
    subset = np.delete(subset, deleteIdx, axis=0)
    
    return (subset, candidate)

Find bounding boxes

In [8]:
body_types = ['face', 'up', 'low', 'body']
# 'face', 'up', 'low'
body_index = [[0, 14, 15, 16, 17], [2, 3, 4, 5, 6, 7, 8, 11], [8, 9, 10, 11, 12, 13]]
body_indices = dict(zip(body_types, body_index))
        
def box_extend(box, ratio, limit):
    l, r, t, b = box
    dx = r - l
    dy = b - t
    l = int(max(l-ratio*dx, 0))
    r = int(min(r+ratio*dx, limit[0]))
    t = int(max(t-ratio*dy, 0))
    b = int(min(b+ratio*dy, limit[1]))
    return [l, r, t, b]
        
def find_person_bounding_box(image_name):
    oriImg = cv.imread(dataset.get_image_path(image_name))
    heatmap_avg, paf_avg = find_heatmap_and_paf(oriImg)
    all_peaks, _ = find_all_peaks(heatmap_avg)
    subset, candidate = find_subset_and_candidate(oriImg, paf_avg, all_peaks)
    persons = []
    
    for n in range(len(subset)):
        bboxes = dict(zip(body_types, [[]] * len(body_types)))
        bboxes['body'] = [oriImg.shape[0], 0, oriImg.shape[1], 0]
        for part in body_indices:
            l = r = t = b = 0
            index = subset[n][np.array(body_indices[part])]
            index = np.delete(index, np.argwhere(index==-1))
            if len(index) != 0:
                Y = candidate[index.astype(int), 0]
                X = candidate[index.astype(int), 1]
                l = int(np.min(X))
                r = int(np.max(X))
                t = int(np.min(Y))
                b = int(np.max(Y))
                if part == 'face':
                    cx = int(np.mean(X))
                    cy = int(np.mean(Y))
                    d = np.mean([r-l, b-t])
                    l = int(cx - d)
                    r = int(cx + d)
                    t = int(cy - d)
                    b = int(cy + d)

            if part != 'face':
                l, r, t, b = box_extend([l, r, t, b], 0.2, (oriImg.shape[0], oriImg.shape[1]))
                
            bboxes[part] = [l, r, t, b]
            if l > 0 or r > 0 or t > 0 or b > 0:
                bboxes['body'][0] = min(bboxes['body'][0], l)
                bboxes['body'][1] = max(bboxes['body'][1], r)
                bboxes['body'][2] = min(bboxes['body'][2], t)
                bboxes['body'][3] = max(bboxes['body'][3], b)

        persons.append(bboxes)
    
    return persons
In [9]:
pool_image_bboxes_name = 'cuhklarge-pool-bboxes.pickle'
pool_image_bboxes = {}

pool_image_bboxes_path = os.path.abspath(os.path.join(pool_image_bboxes_name))
if os.path.exists(pool_image_bboxes_path):
    with open(pool_image_bboxes_path, 'rb') as f:
        pool_image_bboxes = pickle.load(f)
    print('Load pool image bboxes from pickle')
else:
    start_time = time.time()
    for i in range(test_pool_size):
        image_name = dataset. get_test_image_name(i)
        pool_image_bboxes[image_name] = find_person_bounding_box(image_name)
    with open(pool_image_bboxes_path, 'wb') as f:
        pickle.dump(pool_image_bboxes, f)
    print('Total time to find bounding box: %.2f s.' % ((time.time() - start_time)))
    
del net
Load pool image bboxes from pickle

Setup feature extractor

In [10]:
n_bins = 16

def calculate_feature_vector(image):
    image = cv.cvtColor(image, cv.COLOR_BGR2HSV)
    hist = cv.calcHist([image], [2, 1, 0], None, [n_bins, n_bins, n_bins], [0, 181, 0, 256, 0, 256])
    hist = cv.normalize(hist, dst=hist.shape).flatten()
    return hist

def create_features(oriImg, bboxes):
    vectors = dict(zip(body_types, [[]] * len(body_types)))
    for part in bboxes:
        if part == 'body':
            continue
        if len(bboxes[part]) == 0:
            continue
        l, r, t, b = bboxes[part]
        img = oriImg[l:r,t:b,:].copy()
        if np.size(img) == 0:
            continue
        vectors[part] = []
        vectors[part].append(calculate_feature_vector(img))
    return vectors
In [11]:
pool_image_vectors_name = 'cuhklarge-pool-vectors-hsv.pickle'
pool_image_vectors = {}

pool_image_vectors_path = os.path.abspath(os.path.join(pool_image_vectors_name))
if os.path.exists(pool_image_vectors_path):
    with open(pool_image_vectors_path, 'rb') as f:
        pool_image_vectors = pickle.load(f)
    print('Load pool image vectors from pickle')
else:
    start_time = time.time()
    for image_name in pool_image_bboxes:
        oriImg = cv.imread(dataset.get_image_path(image_name))
        size = len(pool_image_bboxes[image_name])
        pool_image_vectors[image_name] = [[]] * size
        for i in range(size):
            pool_image_vectors[image_name][i] = create_features(oriImg, pool_image_bboxes[image_name][i])
    with open(pool_image_vectors_path, 'wb') as f:
        pickle.dump(pool_image_vectors, f)
    print('Total time to create feature vectors: %.2f s.' % ((time.time() - start_time)))
Load pool image vectors from pickle

Prepare query and gallery data

In [12]:
def calculate_iou(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    li = max(boxA[0], boxB[0])
    ri = min(boxA[1], boxB[1])
    ti = max(boxA[2], boxB[2])
    bi = min(boxA[3], boxB[3])
    
    if li > ri or ti > bi:
        return 0.0

    # compute the area of intersection rectangle
    interArea = (ri - li + 1) * (bi - ti + 1)

    # compute the area of both the prediction and ground-truth
    # rectangles
    boxAArea = (boxA[1] - boxA[0] + 1) * (boxA[3] - boxA[2] + 1)
    boxBArea = (boxB[1] - boxB[0] + 1) * (boxB[3] - boxB[2] + 1)

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    iou = interArea / float(boxAArea + boxBArea - interArea)

    # return the intersection over union value
    return iou

def calculate_ioA(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    li = max(boxA[0], boxB[0])
    ri = min(boxA[1], boxB[1])
    ti = max(boxA[2], boxB[2])
    bi = min(boxA[3], boxB[3])
    
    if li > ri or ti > bi:
        return 0.0

    # compute the area of intersection rectangle
    interArea = (ri - li + 1) * (bi - ti + 1)

    # compute the area of both the prediction and ground-truth
    # rectangles
    boxAArea = (boxA[1] - boxA[0] + 1) * (boxA[3] - boxA[2] + 1)

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    ioA = interArea / float(boxAArea)

    # return the intersection over union value
    return ioA

def calculate_ioB(boxA, boxB):
    # determine the (x, y)-coordinates of the intersection rectangle
    li = max(boxA[0], boxB[0])
    ri = min(boxA[1], boxB[1])
    ti = max(boxA[2], boxB[2])
    bi = min(boxA[3], boxB[3])
    
    if li > ri or ti > bi:
        return 0.0

    # compute the area of intersection rectangle
    interArea = (ri - li + 1) * (bi - ti + 1)
    
    boxBArea = (boxB[1] - boxB[0] + 1) * (boxB[3] - boxB[2] + 1)

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the interesection area
    ioB = interArea / float(boxBArea)

    # return the intersection over union value
    return ioB

def is_bounding_box_match_ground_truth(bbox, gtbox, threshold=0.5):
    ioA = calculate_ioA(bbox, gtbox)
    ioB = calculate_ioB(bbox, gtbox)
    return ioA >= threshold or ioB >= threshold
In [13]:
def draw_ground_truth_box(oriImg, gtbox):
    if len(gtbox) == 0:
        return
    
    l, r, t, b  = gtbox
    cv.rectangle(oriImg, (t,l), (b,r), (0,0,255), 2)
    
def draw_found_bounding_box(oriImg, bbox):
    l, r, t, b  = bbox['face']
    cv.rectangle(oriImg, (t,l), (b,r), (0,255,0), 2)
    l, r, t, b  = bbox['up']
    cv.rectangle(oriImg, (t,l), (b,r), (0,255,0), 2)
    l, r, t, b  = bbox['low']
    cv.rectangle(oriImg, (t,l), (b,r), (0,255,0), 2)
    l, r, t, b  = bbox['body']
    cv.rectangle(oriImg, (t,l), (b,r), (255,0,0), 2)

class GalleryObject:
    image_name = None
    gtbox = None
    
    def __init__(self, image_name, gtbox):
        self.image_name = image_name
        self.gtbox = gtbox
        
    def get_found_bboxes(self):
        return pool_image_bboxes[self.image_name]
    
    def get_found_bbox(self, index):
        return self.get_found_bboxes()[index]
    
    def get_found_vectors(self):
        return pool_image_vectors[self.image_name]
    
    def get_found_vector(self, index):
        return self.get_found_vectors()[index]
        
    def draw(self):
        f, axarr = plt.subplots(1, 1)
        f.set_size_inches((20, 20))
        
        oriImg = cv.imread(dataset.get_image_path(self.image_name))
        draw_ground_truth_box(oriImg, self.gtbox)
        found_bboxes = self.get_found_bboxes()
        for bbox in found_bboxes:
            draw_found_bounding_box(oriImg, bbox)
        axarr.imshow(oriImg[:,:,[2,1,0]])
        axarr.set_title(self.image_name)

class QueryObject:
    image_name = None
    gtbox = None
    matched_query_id = None
    
    def __init__(self, image_name, gtbox):
        self.image_name = image_name
        self.gtbox = gtbox
        found_bboxes = self.get_found_bboxes()
        for i in range(len(found_bboxes)):
            bbox = found_bboxes[i]
            if is_bounding_box_match_ground_truth(bbox['body'], self.gtbox):
                self.matched_query_id = i
                break
                
    def get_found_bboxes(self):
        return pool_image_bboxes[self.image_name]
    
    def get_query_bbox(self):
        if self.matched_query_id is None:
            return None
        else:
            return pool_image_bboxes[self.image_name][self.matched_query_id]
    
    def get_found_vectors(self):
        return pool_image_vectors[self.image_name]
    
    def get_query_vector(self):
        if self.matched_query_id is None:
            return None
        else:
            return pool_image_vectors[self.image_name][self.matched_query_id]
                
    def draw(self):
        f, axarr = plt.subplots(1, 2)
        f.set_size_inches((20, 20))
        
        oriImg = cv.imread(dataset.get_image_path(self.image_name))
        found_bboxes = self.get_found_bboxes()
        for bbox in found_bboxes:
            draw_found_bounding_box(oriImg, bbox)
        axarr[0].imshow(oriImg[:,:,[2,1,0]])
        axarr[0].set_title(self.image_name)
        
        oriImg = cv.imread(dataset.get_image_path(self.image_name))
        draw_ground_truth_box(oriImg, self.gtbox)
        bbox = self.get_query_bbox()
        if bbox is not None:
            draw_found_bounding_box(oriImg, bbox)
        axarr[1].imshow(oriImg[:,:,[2,1,0]])
        axarr[1].set_title(self.image_name)
In [14]:
# ['face', 'up', 'low', 'body']
similarity_weight = dict(zip(body_types, [0.0, 1.0, 1.0]))

def find_similarity(query_vector, gallery_vector):
    similarity = dict(zip(body_types, [0] * len(body_types)))
    for part in body_types:
        if part == 'body':
            continue
        
        if np.sum(query_vector[part]) == 0 or np.sum(gallery_vector[part]) == 0:
            similarity[part] = None
        else:
            similarity[part] = 0.0
            methods = len(query_vector[part])
            for m in range(methods):
                similarity[part] += 1.0 - distance.cosine(query_vector[part][m], gallery_vector[part][m])
            similarity[part] /= methods
    
    return similarity

def find_final_similarity(similarity):
    result = 0
    count = 0
    for part in similarity_weight:
        if similarity[part] is not None:
            result += similarity[part] * similarity_weight[part]
            count += similarity_weight[part]
           
    if count > 0:
        result /= count

    return result
In [15]:
class TestSample:
    query_obj = None
    gallery_objs = None
    ap = None
    acc = None
    result = None
    
    def __init__(self, index):
        query_data = dataset.get_test_query_query_data(index)
        self.query_obj = QueryObject(query_data.imname, query_data.idlocate)
        
        self.gallery_objs = []
        for i in range(gallery_size):
            gallery_data = dataset.get_test_query_gallery_data(index, i)
            self.gallery_objs.append(GalleryObject(gallery_data.imname, gallery_data.idlocate))
            
    def _evaluate(self, topk):
        y_true = []
        y_score = []
        imgs = []
        rois = []
        count_gt = 0
        count_tp = 0
        for gobj in self.gallery_objs:
            image_name = gobj.image_name
            gt = np.array(gobj.gtbox, dtype=np.int32)
            count_gt += (gt.size > 0)
            
            found_number = len(gobj.get_found_vectors())
            bbox = np.zeros((found_number, 4), dtype=np.int32)
            sim = np.zeros(found_number, dtype=np.float)
            label = np.zeros(found_number, dtype=np.int32)
            
            for i in range(found_number):
                s = find_similarity(self.query_obj.get_query_vector(), gobj.get_found_vector(i))
                sim[i] = find_final_similarity(s)
                bbox[i,:] = np.array(gobj.get_found_bbox(i)['body'])
            
            if gt.size > 0:
                inds = np.argsort(sim)[::-1]
                sim = sim[inds]
                bbox = bbox[inds]
                for j in range(found_number):
                    if is_bounding_box_match_ground_truth(bbox[j], gobj.gtbox):
                        label[j] = 1
                        count_tp += 1
                        
            y_true.extend(list(label))
            y_score.extend(list(sim))
            imgs.extend([image_name] * found_number)
            rois.extend(list(bbox))
        
        y_true = np.asarray(y_true)
        y_score = np.asarray(y_score)
        recall_rate = float(count_tp) / count_gt
        self.ap = 0 if count_tp == 0 else average_precision_score(y_true, y_score) * recall_rate
        inds = np.argsort(y_score)[::-1]
        y_score = y_score[inds]
        y_true = y_true[inds]
        self.acc = [min(1, sum(y_true[:k])) for k in topk]
        
        self.result = []
        for k in range(10):
            self.result.append(
            {
                'image_name': str(imgs[inds[k]]),
                'bbox': rois[inds[k]],
                'score': float(y_score[k]),
                'correct': int(y_true[k])
            })
    
    def evaluate(self, topk):
        if self.query_obj.matched_query_id is None:
            self.ap = np.nan
            self.acc = [np.nan] * len(topk)
        else:
            self._evaluate(topk)
        

Evaluate similarity

In [16]:
topk = [1, 5, 10]
In [17]:
test_samples = []
aps = []
accs = []
aps_nan = []
accs_nan = []

for i in range(query_size):
    sample = TestSample(i)
    test_samples.append(sample)
    sample.evaluate(topk)
    ap = sample.ap
    acc = sample.acc
    
    aps_nan.append(ap)
    accs_nan.append(acc)
    
    if ap is np.nan:
        ap = 0.0
        acc = [0.0] * len(topk)
    aps.append(ap)
    accs.append(acc)
In [18]:
print('mAP = {:.2%}'.format(np.mean(aps)))
avg_accs = np.mean(accs, axis=0)
for i, k in enumerate(topk):
    print('top-{:2d} = {:.2%}'.format(k, avg_accs[i]))
mAP = 46.43%
top- 1 = 42.79%
top- 5 = 60.03%
top-10 = 67.38%
In [19]:
print('Remove no query bounding box')
print('mAP = {:.2%}'.format(np.nanmean(aps_nan)))
avg_accs_nan = np.nanmean(accs_nan, axis=0)
for i, k in enumerate(topk):
    print('top-{:2d} = {:.2%}'.format(k, avg_accs_nan[i]))
Remove no query bounding box
mAP = 47.66%
top- 1 = 43.93%
top- 5 = 61.63%
top-10 = 69.17%
In [20]:
nan_index = np.argwhere(np.isnan(aps_nan)).ravel()
print('No bounding box sample count: {}'.format(nan_index.size))
print('No bounding box sample index: \n{}'.format(nan_index))
No bounding box sample count: 75
No bounding box sample index: 
[  66  120  121  183  203  269  328  369  380  385  460  486  493  522  532
  570  571  576  590  602  650  693  726  790  913  967  991 1078 1079 1086
 1099 1140 1147 1213 1227 1263 1290 1315 1333 1378 1387 1490 1600 1633 1635
 1655 1704 1721 1743 1744 1864 1869 1893 1915 1938 1962 1967 2054 2095 2107
 2182 2232 2263 2272 2280 2329 2389 2390 2477 2626 2689 2746 2834 2849 2852]
In [21]:
miss_index = []
for i in range(query_size):
    if test_samples[i].acc[0] == 0 and test_samples[i].acc[1] == 1:
        miss_index.append(i)
        
print(len(miss_index))
print(miss_index)
500
[7, 23, 24, 36, 37, 43, 47, 60, 76, 80, 81, 82, 84, 91, 93, 99, 105, 122, 125, 136, 144, 147, 151, 153, 167, 168, 179, 181, 184, 187, 192, 213, 218, 220, 231, 235, 239, 242, 247, 249, 250, 261, 277, 292, 293, 295, 310, 318, 324, 334, 336, 337, 340, 341, 344, 348, 368, 393, 395, 399, 433, 438, 445, 469, 471, 475, 476, 477, 480, 485, 497, 503, 516, 519, 521, 530, 531, 542, 543, 544, 545, 557, 560, 585, 606, 611, 637, 653, 656, 659, 665, 689, 694, 696, 719, 720, 723, 731, 739, 743, 750, 755, 767, 769, 786, 795, 803, 810, 822, 826, 832, 833, 835, 853, 862, 873, 878, 883, 886, 908, 909, 917, 922, 931, 936, 939, 941, 943, 944, 947, 952, 954, 958, 961, 962, 963, 973, 979, 987, 988, 1002, 1005, 1011, 1014, 1037, 1039, 1042, 1061, 1069, 1080, 1087, 1092, 1095, 1096, 1118, 1125, 1127, 1138, 1143, 1146, 1151, 1153, 1155, 1161, 1162, 1178, 1179, 1180, 1184, 1192, 1203, 1205, 1209, 1215, 1217, 1218, 1225, 1230, 1242, 1245, 1252, 1256, 1265, 1274, 1277, 1282, 1284, 1288, 1292, 1293, 1299, 1305, 1306, 1320, 1326, 1329, 1335, 1336, 1340, 1348, 1349, 1356, 1364, 1366, 1369, 1381, 1396, 1397, 1407, 1426, 1427, 1429, 1434, 1435, 1439, 1440, 1447, 1452, 1454, 1455, 1467, 1476, 1511, 1514, 1518, 1519, 1521, 1527, 1529, 1543, 1547, 1560, 1576, 1583, 1587, 1589, 1592, 1593, 1598, 1604, 1613, 1623, 1630, 1645, 1646, 1649, 1654, 1657, 1662, 1666, 1667, 1674, 1675, 1679, 1683, 1685, 1687, 1689, 1692, 1710, 1711, 1714, 1719, 1733, 1748, 1758, 1762, 1763, 1764, 1772, 1776, 1781, 1784, 1788, 1797, 1810, 1813, 1825, 1832, 1846, 1850, 1851, 1858, 1863, 1865, 1880, 1883, 1888, 1894, 1897, 1898, 1900, 1911, 1913, 1922, 1923, 1927, 1932, 1934, 1937, 1941, 1946, 1952, 1954, 1957, 1960, 1961, 1972, 1975, 1976, 1982, 1995, 1999, 2002, 2004, 2020, 2025, 2027, 2028, 2029, 2038, 2039, 2040, 2056, 2068, 2069, 2081, 2085, 2087, 2094, 2112, 2115, 2122, 2123, 2124, 2126, 2127, 2130, 2136, 2152, 2161, 2172, 2174, 2180, 2191, 2193, 2195, 2202, 2205, 2219, 2224, 2226, 2231, 2237, 2240, 2247, 2248, 2257, 2259, 2262, 2265, 2266, 2269, 2275, 2289, 2291, 2296, 2297, 2302, 2304, 2309, 2310, 2314, 2318, 2321, 2324, 2326, 2332, 2335, 2342, 2347, 2354, 2356, 2357, 2360, 2361, 2365, 2367, 2372, 2374, 2376, 2388, 2392, 2396, 2399, 2402, 2406, 2419, 2421, 2422, 2423, 2427, 2431, 2433, 2437, 2438, 2443, 2445, 2452, 2459, 2480, 2481, 2500, 2509, 2517, 2518, 2537, 2539, 2540, 2552, 2560, 2562, 2563, 2564, 2568, 2574, 2576, 2580, 2584, 2585, 2588, 2589, 2596, 2608, 2616, 2618, 2621, 2630, 2635, 2639, 2640, 2657, 2660, 2661, 2662, 2664, 2665, 2675, 2676, 2678, 2683, 2684, 2691, 2694, 2696, 2700, 2706, 2707, 2711, 2717, 2724, 2725, 2727, 2730, 2743, 2754, 2762, 2765, 2769, 2772, 2774, 2775, 2777, 2780, 2782, 2787, 2790, 2794, 2797, 2809, 2813, 2817, 2820, 2832, 2836, 2843, 2844, 2847, 2856, 2859, 2862, 2868, 2869, 2875, 2878, 2879, 2882, 2891, 2892, 2898]
In [22]:
lbp_hsv_miss_index = [  82,   93,  122,  247,  261,  336,  340,  344,  469,  516,  519,
        545,  557,  659,  689,  826,  873,  878,  936,  941,  954,  979,
       1002, 1095, 1127, 1209, 1217, 1225, 1252, 1366, 1381, 1396, 1547,
       1604, 1630, 1646, 1710, 1733, 1748, 1788, 1911, 1937, 1941, 1972,
       1999, 2004, 2020, 2068, 2124, 2127, 2152, 2219, 2231, 2275, 2291,
       2321, 2332, 2374, 2431, 2437, 2452, 2517, 2576, 2580, 2635, 2660,
       2696, 2762, 2774, 2820, 2836, 2862]
In [23]:
correct_index = []
for i in range(query_size):
    if test_samples[i].acc[0] == 1:
        correct_index.append(i)
        
print(len(correct_index))
print(correct_index)
1241
[0, 1, 3, 4, 5, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 39, 41, 42, 44, 45, 46, 48, 50, 51, 52, 53, 55, 56, 57, 58, 62, 63, 67, 68, 69, 70, 72, 74, 75, 77, 78, 79, 85, 87, 88, 90, 92, 94, 95, 97, 98, 101, 102, 103, 104, 106, 107, 108, 109, 110, 112, 114, 116, 118, 119, 123, 126, 128, 129, 130, 131, 132, 134, 135, 137, 139, 140, 143, 145, 149, 150, 152, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 169, 170, 171, 172, 173, 174, 175, 176, 178, 180, 186, 190, 191, 193, 199, 200, 201, 202, 204, 205, 207, 208, 211, 212, 215, 216, 217, 219, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 232, 233, 236, 237, 238, 240, 241, 243, 248, 251, 252, 254, 255, 256, 257, 258, 260, 262, 265, 266, 267, 268, 270, 271, 272, 273, 275, 278, 280, 281, 283, 284, 285, 286, 287, 288, 289, 290, 291, 294, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 311, 312, 313, 314, 315, 317, 319, 321, 322, 323, 325, 326, 329, 330, 331, 333, 335, 339, 342, 343, 345, 346, 347, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 364, 365, 366, 367, 371, 373, 374, 376, 378, 379, 381, 382, 384, 386, 388, 389, 390, 391, 392, 394, 396, 397, 398, 400, 401, 403, 405, 406, 407, 409, 410, 411, 412, 413, 415, 416, 417, 418, 419, 420, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 434, 435, 439, 440, 441, 442, 443, 444, 447, 448, 449, 451, 452, 453, 454, 455, 456, 458, 461, 462, 463, 464, 465, 466, 467, 468, 473, 478, 479, 483, 484, 487, 489, 491, 492, 494, 495, 496, 498, 499, 501, 502, 504, 506, 507, 508, 509, 512, 513, 514, 515, 518, 520, 523, 524, 526, 527, 528, 529, 533, 534, 535, 537, 539, 540, 546, 548, 549, 550, 551, 552, 554, 556, 559, 561, 562, 563, 565, 566, 568, 569, 572, 575, 577, 578, 579, 580, 582, 583, 584, 586, 589, 591, 592, 593, 594, 595, 596, 597, 599, 600, 601, 603, 604, 609, 610, 612, 614, 615, 617, 618, 619, 621, 623, 624, 625, 626, 627, 628, 629, 632, 633, 634, 635, 636, 638, 641, 642, 643, 644, 645, 647, 648, 651, 652, 654, 657, 658, 660, 661, 662, 664, 666, 667, 668, 669, 671, 672, 673, 675, 677, 680, 681, 682, 684, 686, 687, 690, 691, 692, 695, 697, 700, 701, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 715, 716, 717, 718, 721, 722, 724, 725, 727, 729, 733, 734, 738, 740, 741, 742, 745, 746, 747, 748, 752, 753, 754, 756, 757, 760, 761, 762, 763, 764, 765, 766, 770, 771, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 788, 789, 792, 793, 796, 798, 800, 802, 804, 805, 806, 808, 811, 813, 814, 815, 816, 818, 819, 820, 821, 824, 825, 829, 830, 831, 834, 836, 837, 838, 839, 840, 842, 844, 847, 848, 849, 850, 851, 856, 857, 858, 859, 860, 863, 865, 866, 867, 868, 869, 870, 871, 874, 876, 879, 880, 881, 882, 884, 885, 887, 891, 892, 893, 894, 895, 896, 897, 898, 899, 901, 904, 906, 907, 918, 920, 921, 923, 927, 928, 933, 935, 938, 946, 950, 955, 956, 964, 968, 969, 974, 976, 980, 983, 985, 986, 992, 993, 995, 998, 999, 1000, 1003, 1004, 1006, 1008, 1009, 1010, 1013, 1016, 1017, 1018, 1021, 1027, 1029, 1030, 1033, 1034, 1035, 1038, 1044, 1048, 1050, 1052, 1054, 1055, 1056, 1057, 1058, 1062, 1063, 1067, 1068, 1070, 1072, 1073, 1075, 1077, 1082, 1084, 1085, 1088, 1093, 1101, 1102, 1103, 1104, 1106, 1107, 1111, 1115, 1121, 1122, 1123, 1128, 1129, 1130, 1133, 1134, 1136, 1139, 1145, 1149, 1150, 1154, 1158, 1164, 1165, 1170, 1171, 1173, 1186, 1189, 1196, 1197, 1198, 1201, 1202, 1204, 1207, 1211, 1216, 1219, 1222, 1223, 1226, 1228, 1231, 1233, 1234, 1237, 1238, 1246, 1247, 1254, 1255, 1260, 1262, 1266, 1270, 1273, 1278, 1279, 1280, 1283, 1287, 1291, 1307, 1308, 1309, 1311, 1312, 1318, 1323, 1324, 1325, 1331, 1332, 1334, 1338, 1343, 1346, 1350, 1351, 1357, 1358, 1359, 1360, 1365, 1367, 1371, 1372, 1374, 1377, 1379, 1382, 1388, 1389, 1390, 1391, 1392, 1393, 1395, 1398, 1400, 1401, 1403, 1404, 1408, 1416, 1417, 1419, 1420, 1421, 1422, 1428, 1430, 1431, 1432, 1436, 1437, 1442, 1443, 1444, 1446, 1449, 1462, 1463, 1469, 1470, 1473, 1477, 1478, 1480, 1482, 1483, 1484, 1485, 1486, 1488, 1491, 1492, 1497, 1499, 1500, 1503, 1504, 1505, 1526, 1528, 1530, 1535, 1538, 1539, 1540, 1542, 1544, 1545, 1546, 1548, 1561, 1562, 1567, 1569, 1570, 1573, 1577, 1582, 1585, 1586, 1590, 1591, 1594, 1601, 1603, 1607, 1609, 1610, 1614, 1617, 1620, 1624, 1625, 1627, 1628, 1638, 1640, 1641, 1643, 1647, 1651, 1656, 1659, 1660, 1661, 1664, 1665, 1668, 1669, 1670, 1671, 1677, 1686, 1691, 1693, 1694, 1697, 1698, 1699, 1702, 1703, 1705, 1707, 1709, 1712, 1713, 1718, 1722, 1724, 1725, 1726, 1727, 1730, 1732, 1735, 1736, 1737, 1738, 1742, 1745, 1749, 1750, 1751, 1753, 1760, 1761, 1766, 1768, 1769, 1774, 1775, 1777, 1778, 1779, 1780, 1782, 1783, 1785, 1791, 1793, 1794, 1796, 1798, 1799, 1802, 1803, 1806, 1814, 1815, 1816, 1817, 1819, 1820, 1826, 1828, 1830, 1831, 1834, 1835, 1837, 1838, 1841, 1842, 1843, 1847, 1848, 1854, 1855, 1857, 1859, 1878, 1879, 1881, 1882, 1884, 1891, 1895, 1896, 1901, 1907, 1910, 1917, 1918, 1920, 1926, 1929, 1939, 1942, 1944, 1948, 1949, 1955, 1968, 1970, 1977, 1981, 1983, 1988, 1989, 1994, 2001, 2003, 2005, 2006, 2007, 2009, 2011, 2013, 2014, 2016, 2017, 2018, 2022, 2023, 2024, 2030, 2031, 2034, 2037, 2041, 2043, 2045, 2049, 2050, 2053, 2055, 2060, 2061, 2065, 2066, 2067, 2072, 2073, 2076, 2077, 2078, 2080, 2089, 2091, 2093, 2098, 2099, 2100, 2105, 2111, 2113, 2116, 2118, 2120, 2121, 2128, 2132, 2133, 2134, 2138, 2143, 2145, 2147, 2153, 2155, 2156, 2157, 2159, 2160, 2170, 2171, 2175, 2176, 2179, 2184, 2185, 2188, 2190, 2201, 2204, 2206, 2210, 2211, 2212, 2213, 2214, 2215, 2221, 2223, 2225, 2228, 2233, 2236, 2242, 2243, 2244, 2246, 2250, 2252, 2254, 2267, 2271, 2273, 2277, 2279, 2281, 2283, 2284, 2285, 2286, 2290, 2292, 2295, 2298, 2299, 2303, 2305, 2311, 2312, 2315, 2317, 2319, 2322, 2331, 2333, 2334, 2341, 2344, 2351, 2352, 2353, 2359, 2362, 2364, 2368, 2371, 2373, 2377, 2380, 2381, 2383, 2386, 2391, 2393, 2395, 2401, 2403, 2404, 2410, 2417, 2418, 2426, 2428, 2429, 2435, 2436, 2439, 2440, 2442, 2444, 2449, 2464, 2465, 2469, 2474, 2479, 2483, 2486, 2488, 2490, 2496, 2501, 2503, 2506, 2511, 2512, 2513, 2516, 2519, 2522, 2527, 2529, 2530, 2531, 2533, 2538, 2544, 2548, 2554, 2556, 2559, 2561, 2565, 2566, 2567, 2578, 2579, 2581, 2590, 2593, 2597, 2601, 2609, 2610, 2612, 2613, 2622, 2623, 2624, 2625, 2628, 2629, 2631, 2632, 2633, 2634, 2636, 2638, 2643, 2644, 2645, 2646, 2654, 2655, 2656, 2658, 2663, 2666, 2667, 2668, 2669, 2670, 2671, 2673, 2674, 2682, 2686, 2690, 2698, 2699, 2701, 2702, 2718, 2721, 2722, 2728, 2729, 2736, 2740, 2741, 2742, 2745, 2751, 2752, 2760, 2761, 2763, 2764, 2767, 2768, 2770, 2781, 2784, 2785, 2788, 2791, 2792, 2795, 2796, 2801, 2802, 2804, 2807, 2808, 2812, 2814, 2819, 2822, 2825, 2826, 2827, 2830, 2850, 2858, 2865, 2871, 2873, 2880, 2883, 2884, 2887, 2888, 2894, 2896, 2897]
In [26]:
hsv_miss_lbp_correct_index = [  24,   91,  399,  497,  585,  656,  720,  731,  750,  769,  939,
       1037, 1087, 1335, 1340, 1467, 1687, 1763, 1797, 1894, 2174, 2180,
       2237, 2259, 2438, 2480, 2684, 2707, 2717, 2727, 2730, 2832]
In [27]:
sampling_miss_index = hsv_miss_lbp_correct_index
sampling_size = len(sampling_miss_index)
count_per_sample = 5
count_per_row = (count_per_sample+1)//2

for i in range(sampling_size):
    index = sampling_miss_index[i]
    sample = test_samples[index]
    f, axarr = plt.subplots(2, count_per_row)
    f.set_size_inches((20, 10))
    
    img = cv.imread(dataset.get_image_path(sample.query_obj.image_name))
    l, r, t, b = sample.query_obj.get_query_bbox()['body']
    cv.rectangle(img, (t,l), (b,r), (0,0,255), 2)
    axarr[0,0].imshow(img[:,:,[2,1,0]])
    axarr[0,0].set_title('id - {}'.format(index))
    
    for j in range(count_per_sample):
        img = cv.imread(dataset.get_image_path(sample.result[j]['image_name']))
        l, r, t, b  = sample.result[j]['bbox']
        cv.rectangle(img, (t,l), (b,r), (0,0,255), 2)
        
        row = (j+1) // count_per_row
        col = (j+1) % count_per_row
        
        axarr[row,col].imshow(img[:,:,[2,1,0]])
        axarr[row,col].set_title('{} - {}'.format(sample.result[j]['score'], sample.result[j]['correct']))
/opt/conda/lib/python3.6/site-packages/matplotlib/pyplot.py:524: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  max_open_warning, RuntimeWarning)
In [ ]: